{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 451
    },
    "colab_type": "code",
    "id": "IZ0rjD0mvNHo",
    "outputId": "b90c5729-cdf7-4cb6-d200-011c5277547f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: transformers in /usr/local/lib/python3.6/dist-packages (2.6.0)\n",
      "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.21.0)\n",
      "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.38.0)\n",
      "Requirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from transformers) (1.12.26)\n",
      "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.18.2)\n",
      "Requirement already satisfied: sacremoses in /usr/local/lib/python3.6/dist-packages (from transformers) (0.0.38)\n",
      "Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)\n",
      "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.6/dist-packages (from transformers) (0.1.85)\n",
      "Requirement already satisfied: tokenizers==0.5.2 in /usr/local/lib/python3.6/dist-packages (from transformers) (0.5.2)\n",
      "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2019.11.28)\n",
      "Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.8)\n",
      "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)\n",
      "Requirement already satisfied: s3transfer<0.4.0,>=0.3.0 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.3.3)\n",
      "Requirement already satisfied: botocore<1.16.0,>=1.15.26 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (1.15.26)\n",
      "Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.9.5)\n",
      "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.12.0)\n",
      "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.14.1)\n",
      "Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.1)\n",
      "Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /usr/local/lib/python3.6/dist-packages (from botocore<1.16.0,>=1.15.26->boto3->transformers) (2.8.1)\n",
      "Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.16.0,>=1.15.26->boto3->transformers) (0.15.2)\n",
      "fatal: destination path 'transformers' already exists and is not an empty directory.\n"
     ]
    }
   ],
   "source": [
    "!pip install transformers\n",
    "!git clone https://github.com/huggingface/transformers.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "vHWQYKDGRew6"
   },
   "source": [
    "# BertForSequenceClassification情感分析"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "7nITXqUoqxp4"
   },
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "EmmoVpH4tlo9"
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "import pandas as pd\n",
    "\n",
    "class SentimentDataset(Dataset):\n",
    "\tdef __init__(self, path_to_file):\n",
    "\t\tself.dataset = pd.read_csv(path_to_file, sep=\"\\t\", names=[\"text\", \"label\"])\n",
    "\tdef __len__(self):\n",
    "\t\treturn len(self.dataset)\n",
    "\tdef __getitem__(self, idx):\n",
    "\t\ttext = self.dataset.loc[idx, \"text\"]\n",
    "\t\tlabel = self.dataset.loc[idx, \"label\"]\n",
    "\t\tsample = {\"text\": text, \"label\": label}\n",
    "\t\treturn sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 64
    },
    "colab_type": "code",
    "id": "7NW9E58avD6K",
    "outputId": "e823f347-08bc-4a2f-e242-b8c84617a3a7"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<p style=\"color: red;\">\n",
       "The default version of TensorFlow in Colab will switch to TensorFlow 2.x on the 27th of March, 2020.<br>\n",
       "We recommend you <a href=\"https://www.tensorflow.org/guide/migrate\" target=\"_blank\">upgrade</a> now\n",
       "or ensure your notebook will continue to use TensorFlow 1.x via the <code>%tensorflow_version 1.x</code> magic:\n",
       "<a href=\"https://colab.research.google.com/notebooks/tensorflow_version.ipynb\" target=\"_blank\">more info</a>.</p>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from transformers import BertTokenizer\n",
    "from transformers import BertForSequenceClassification, BertModel\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import AdamW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "PZblz3sevYil"
   },
   "outputs": [],
   "source": [
    "# 超参数\n",
    "hidden_dropout_prob = 0.5\n",
    "num_labels = 2\n",
    "learning_rate = 1e-5\n",
    "weight_decay = 1e-2\n",
    "epochs = 5\n",
    "max_len = 100\n",
    "batch_size = 16\n",
    "class_num = 2\n",
    "\n",
    "base_path = \"/content/drive/My Drive/Colab Notebooks/\"\n",
    "vocab_file = base_path + \"PyTorch_Pretrained_Model/chinese_wwm_pytorch/vocab.txt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "id": "r5Nm0G0HwMtF",
    "outputId": "d6448a99-99cb-4f8a-8314-20b3fe2940e8"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Tesla P100-PCIE-16GB'"
      ]
     },
     "execution_count": 8,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.cuda.get_device_name(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "4X4oXRVDwE__"
   },
   "outputs": [],
   "source": [
    "# 使用GPU\n",
    "# 然后通过model.to(device)的方式使用\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "W3XdNer6MQyn"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Vdjl8dzSwUqC"
   },
   "outputs": [],
   "source": [
    "data_path = base_path + \"/data/sentiment/\"\n",
    "# 加载数据集\n",
    "sentiment_train_set = SentimentDataset(data_path + \"sentiment.train.data\")\n",
    "sentiment_train_loader = DataLoader(sentiment_train_set, batch_size=batch_size, shuffle=True, num_workers=2)\n",
    "\n",
    "sentiment_valid_set = SentimentDataset(data_path + \"sentiment.valid.data\")\n",
    "sentiment_valid_loader = DataLoader(sentiment_valid_set, batch_size=batch_size, shuffle=False, num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "yZ6GFtFOwh_W"
   },
   "outputs": [],
   "source": [
    "# 加载模型\n",
    "model = BertForSequenceClassification.from_pretrained(\"bert-base-uncased\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "hc1_wjmOxHQZ"
   },
   "outputs": [],
   "source": [
    "# 定义优化器和损失函数\n",
    "# Prepare optimizer and schedule (linear warmup and decay)\n",
    "no_decay = ['bias', 'LayerNorm.weight']\n",
    "optimizer_grouped_parameters = [\n",
    "        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},\n",
    "        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
    "]\n",
    "#optimizer = AdamW(model.parameters(), lr=learning_rate)\n",
    "optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "tokenizer = BertTokenizer(vocab_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "NzULEB8MyzR7"
   },
   "outputs": [],
   "source": [
    "def convert_text_to_ids(tokenizer, text, max_len=100):\n",
    "\tif isinstance(text, str):\n",
    "\t\ttokenized_text = tokenizer.encode_plus(text, max_length=max_len, add_special_tokens=True)\n",
    "\t\tinput_ids = tokenized_text[\"input_ids\"]\n",
    "\t\ttoken_type_ids = tokenized_text[\"token_type_ids\"]\n",
    "\telif isinstance(text, list):\n",
    "\t\tinput_ids = []\n",
    "\t\ttoken_type_ids = []\n",
    "\t\tfor t in text:\n",
    "\t\t\ttokenized_text = tokenizer.encode_plus(t, max_length=max_len, add_special_tokens=True)\n",
    "\t\t\tinput_ids.append(tokenized_text[\"input_ids\"])\n",
    "\t\t\ttoken_type_ids.append(tokenized_text[\"token_type_ids\"])\n",
    "\telse:\n",
    "\t\tprint(\"Unexpected input\")\n",
    "\treturn input_ids, token_type_ids\n",
    "\n",
    "\n",
    "def seq_padding(tokenizer, X):\n",
    "\t# 需要 LongTensor\n",
    "\tpad_id = tokenizer.convert_tokens_to_ids(\"[PAD]\")\n",
    "\tif len(X) <= 1:\n",
    "\t\treturn torch.tensor(X, dtype=torch.long)\n",
    "\tL = [len(x) for x in X]\n",
    "\tML = max(L)\n",
    "\tX = torch.tensor([x + [pad_id] * (ML - len(x)) if len(x) < ML else x for x in X], dtype=torch.long)\n",
    "\treturn X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "NIkFb-Hey1ov"
   },
   "outputs": [],
   "source": [
    "def train(model, iterator, optimizer, criterion, device):\n",
    "\tmodel.to(device)\n",
    "\tmodel.train()\n",
    "\tepoch_loss = 0\n",
    "\tepoch_acc = 0\n",
    "\tfor i, batch in enumerate(iterator):\n",
    "\t\tlabel = batch[\"label\"]\n",
    "\t\ttext = batch[\"text\"]\n",
    "\t\tinput_ids, token_type_ids = convert_text_to_ids(tokenizer, text)\n",
    "\t\tinput_ids = seq_padding(tokenizer, input_ids)\n",
    "\t\ttoken_type_ids = seq_padding(tokenizer, token_type_ids)\n",
    "\t\t# 标签形状为 (batch_size, 1) \n",
    "\t\tlabel = label.unsqueeze(1)\n",
    "\t\t# 梯度清零\n",
    "\t\toptimizer.zero_grad()\n",
    "\t\t# 迁移到GPU\n",
    "\t\tinput_ids, token_type_ids, label = input_ids.to(device), token_type_ids.to(device), label.to(device)\n",
    "\t\t# (loss), logits, (hidden_states), (attentions)\n",
    "\t\t# (hidden_states), (attentions) 不一定存在\n",
    "\t\toutput = model(input_ids=input_ids, token_type_ids=token_type_ids, labels=label)\n",
    "\t\ty_pred_prob = output[1]\n",
    "\t\ty_pred_label = y_pred_prob.argmax(dim=1)\n",
    "\t\t# 计算loss\n",
    "\t\t#loss = criterion(y_pred_prob.view(-1, 2), label.view(-1))\n",
    "\t\tloss = output[0]\n",
    "\t\t# 计算acc\n",
    "\t\tacc = ((y_pred_label == label.view(-1)).sum()).item()\n",
    "\t\t# 反向传播\n",
    "\t\tloss.backward()\n",
    "\t\toptimizer.step()\n",
    "\t\t# epoch 中的 loss 和 acc 累加\n",
    "\t\tepoch_loss += loss.item()\n",
    "\t\tepoch_acc += acc\n",
    "\t\tif i % 100 == 0:\n",
    "\t\t\tprint(\"current loss:\", epoch_loss / (i+1), \"\\t\", \"current acc:\", epoch_acc / ((i+1)*len(label)))\n",
    "\t# return epoch_loss / len(iterator), epoch_acc / (len(iterator) * iterator.batch_size)\n",
    "\treturn epoch_loss / len(iterator), epoch_acc / len(iterator.dataset.dataset)\n",
    "\n",
    "def evaluate(model, iterator, criterion, device):\n",
    "\tmodel.to(device)\n",
    "\tmodel.eval()\n",
    "\tepoch_loss = 0\n",
    "\tepoch_acc = 0\n",
    "\twith torch.no_grad():\n",
    "\t\tfor _, batch in enumerate(iterator):\n",
    "\t\t\tlabel = batch[\"label\"]\n",
    "\t\t\ttext = batch[\"text\"]\n",
    "\t\t\tinput_ids, token_type_ids = convert_text_to_ids(tokenizer, text)\n",
    "\t\t\tinput_ids = seq_padding(tokenizer, input_ids)\n",
    "\t\t\ttoken_type_ids = seq_padding(tokenizer, token_type_ids)\n",
    "\t\t\t# 标签形状为 (batch_size, 1) \n",
    "\t\t\tlabel = label.unsqueeze(1)\n",
    "\t\t\t# 迁移到GPU\n",
    "\t\t\tinput_ids, token_type_ids, label = input_ids.to(device), token_type_ids.to(device), label.to(device)\n",
    "\t\t\toutput = model(input_ids=input_ids, token_type_ids=token_type_ids, labels=label)\n",
    "\t\t\ty_pred_label = output[1].argmax(dim=1)\n",
    "\t\t\tloss = output[0]\n",
    "\t\t\tacc = ((y_pred_label == label.view(-1)).sum()).item()\n",
    "\t\t\tepoch_loss += loss.item()\n",
    "\t\t\tepoch_acc += acc\n",
    "\t# return epoch_loss / len(iterator), epoch_acc / (len(iterator) * iterator.batch_size)\n",
    "\treturn epoch_loss / len(iterator), epoch_acc / len(iterator.dataset.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "id": "qaTC2OtVy4Gx",
    "outputId": "c1ca3cdd-0686-473b-b410-4e7f1387c65f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "current loss: 0.7206894159317017 \t current acc: 0.375\n",
      "current loss: 0.6639068952881464 \t current acc: 0.5928217821782178\n",
      "current loss: 0.659623192762261 \t current acc: 0.605410447761194\n",
      "current loss: 0.6472823614891977 \t current acc: 0.6243770764119602\n",
      "current loss: 0.6229137607717752 \t current acc: 0.6510286783042394\n",
      "current loss: 0.6070191398887577 \t current acc: 0.6681636726546906\n",
      "current loss: 0.5884566437236084 \t current acc: 0.6857321131447587\n",
      "current loss: 0.5727972561086976 \t current acc: 0.7018544935805991\n",
      "current loss: 0.5588198389937071 \t current acc: 0.7135611735330837\n",
      "current loss: 0.5455629944007484 \t current acc: 0.7237791342952276\n",
      "current loss: 0.5343364701493756 \t current acc: 0.7319555444555444\n",
      "train loss:  0.5273547310449346 \t train acc: 0.736623314923689\n",
      "valid loss:  0.4159237858133786 \t valid acc: 0.8180956892468024\n",
      "\n",
      "current loss: 0.2922477126121521 \t current acc: 0.875\n",
      "current loss: 0.3790179410635835 \t current acc: 0.8366336633663366\n",
      "current loss: 0.3748158735422353 \t current acc: 0.8367537313432836\n",
      "current loss: 0.36301204443373947 \t current acc: 0.8453073089700996\n",
      "current loss: 0.3520422117445534 \t current acc: 0.8524002493765586\n",
      "current loss: 0.3459386094631311 \t current acc: 0.8566616766467066\n",
      "current loss: 0.34129169045292 \t current acc: 0.8591930116472546\n",
      "current loss: 0.3296213327042131 \t current acc: 0.8650142653352354\n",
      "current loss: 0.32691875075486565 \t current acc: 0.866729088639201\n",
      "current loss: 0.32235985335694434 \t current acc: 0.8686875693673696\n",
      "current loss: 0.3177544258415818 \t current acc: 0.870004995004995\n",
      "train loss:  0.31593618676745655 \t train acc: 0.8709543322050003\n",
      "valid loss:  0.276584953041465 \t valid acc: 0.8905731880625296\n",
      "\n",
      "current loss: 0.07927422225475311 \t current acc: 1.0\n",
      "current loss: 0.2317083547315975 \t current acc: 0.9096534653465347\n",
      "current loss: 0.22385860447637476 \t current acc: 0.9144900497512438\n",
      "current loss: 0.22587976112120176 \t current acc: 0.9132059800664452\n",
      "current loss: 0.22374588512162913 \t current acc: 0.9131857855361596\n",
      "current loss: 0.21980370914269826 \t current acc: 0.9160429141716567\n",
      "current loss: 0.2199315383174951 \t current acc: 0.9148294509151415\n",
      "current loss: 0.21981334195242458 \t current acc: 0.9148537803138374\n",
      "current loss: 0.21821277354205593 \t current acc: 0.9156523096129837\n",
      "current loss: 0.21860016676102176 \t current acc: 0.9147475027746947\n",
      "current loss: 0.21680165681760985 \t current acc: 0.9158966033966034\n",
      "train loss:  0.2169776937171661 \t train acc: 0.9157313379654374\n",
      "valid loss:  0.23261333691577116 \t valid acc: 0.916153481762198\n",
      "\n",
      "current loss: 0.1601882129907608 \t current acc: 0.875\n",
      "current loss: 0.13249680805619402 \t current acc: 0.9573019801980198\n",
      "current loss: 0.15065006368714778 \t current acc: 0.9496268656716418\n",
      "current loss: 0.15459052267858753 \t current acc: 0.946843853820598\n",
      "current loss: 0.15110565348530647 \t current acc: 0.9476309226932669\n",
      "current loss: 0.14971649017698035 \t current acc: 0.9474800399201597\n",
      "current loss: 0.15668147178785177 \t current acc: 0.9434276206322796\n",
      "current loss: 0.1543601096160019 \t current acc: 0.9440977175463623\n",
      "current loss: 0.15314887960054233 \t current acc: 0.9442883895131086\n",
      "current loss: 0.15253860828829527 \t current acc: 0.9436736958934517\n",
      "current loss: 0.15207287311911225 \t current acc: 0.9442432567432567\n",
      "train loss:  0.15112965823932137 \t train acc: 0.9444147514698022\n",
      "valid loss:  0.29913669106352964 \t valid acc: 0.9043107531975367\n",
      "\n",
      "current loss: 0.15637311339378357 \t current acc: 0.9375\n",
      "current loss: 0.09597253295971025 \t current acc: 0.9647277227722773\n",
      "current loss: 0.09822458206717648 \t current acc: 0.9645522388059702\n",
      "current loss: 0.10030069702190418 \t current acc: 0.9636627906976745\n",
      "current loss: 0.0959477037971751 \t current acc: 0.96571072319202\n",
      "current loss: 0.09663735423422978 \t current acc: 0.9654441117764471\n",
      "current loss: 0.09827782687872103 \t current acc: 0.9651622296173045\n",
      "current loss: 0.100696504615598 \t current acc: 0.9642475035663338\n",
      "current loss: 0.10024348035984941 \t current acc: 0.9645755305867666\n",
      "current loss: 0.10211012197180871 \t current acc: 0.9640677025527192\n",
      "current loss: 0.10165044539175429 \t current acc: 0.964035964035964\n",
      "train loss:  0.10147943017086242 \t train acc: 0.9637745709365164\n",
      "valid loss:  0.26174746668248466 \t valid acc: 0.9156797726196115\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for i in range(epochs):\n",
    "\ttrain_loss, train_acc = train(model, sentiment_train_loader, optimizer, criterion, device)\n",
    "\tprint(\"train loss: \", train_loss, \"\\t\", \"train acc:\", train_acc)\n",
    "\tvalid_loss, valid_acc = evaluate(model, sentiment_valid_loader, criterion, device)\n",
    "\tprint(\"valid loss: \", valid_loss, \"\\t\", \"valid acc:\", valid_acc, end=\"\\n\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 71
    },
    "colab_type": "code",
    "id": "0FwGSeWj3JS7",
    "outputId": "3e01be31-8822-46aa-877f-a1f3257bd343"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('./saved_tokenizer/vocab.txt',\n",
       " './saved_tokenizer/special_tokens_map.json',\n",
       " './saved_tokenizer/added_tokens.json')"
      ]
     },
     "execution_count": 20,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "saved_model = \"./saved_model\"\n",
    "saved_tokenizer = \"./saved_tokenizer\"\n",
    "os.makedirs(saved_model)\n",
    "os.makedirs(saved_tokenizer)\n",
    "model.save_pretrained(saved_model)\n",
    "tokenizer.save_pretrained(saved_tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "PyTorch.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
